{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Afl-y-M6k1sk"
   },
   "source": [
    "# Tutorial: merging LiDAR data and extracting LiDAR features for WOMD\n",
    "\n",
    "This tutorial demonstrates how to add LiDAR data to the original Waymo Open Motion Dataset (WOMD) scenes. It also shows how to extract LiDAR points and features from the merged scenario proto message. WOMD also provides APIs for loading camera data; see the tutorial `tutorial_womd_camera.ipynb`.\n",
    "\n",
    "## Install\n",
    "\n",
    "To run Jupyter Notebook locally:\n",
    "\n",
    "```\n",
    "python3 -m pip install waymo-open-dataset-tf-2-12-0==1.6.7\n",
    "python3 -m pip install \"notebook>=5.3\" \"ipywidgets>=7.5\"\n",
    "python3 -m pip install --upgrade \"jupyter_http_over_ws>=0.0.7\" && \\\n",
    "jupyter serverextension enable --py jupyter_http_over_ws\n",
    "jupyter notebook\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "03uEfb7cCodZ"
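   },
   "outputs": [],
   "source": [
    "import os\n",
    "from typing import List, Dict, Tuple, Optional, Any\n",
    "\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "# TF 2.x (installed above) already runs eagerly and has no top-level\n",
    "# `tf.enable_eager_execution`; the compat.v1 call is a harmless no-op there.\n",
    "tf.compat.v1.enable_eager_execution()\n",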
    "\n",
    "from waymo_open_dataset import dataset_pb2\n",
    "from waymo_open_dataset.protos import scenario_pb2\n",
    "from waymo_open_dataset.protos import compressed_lidar_pb2\n",
    "from waymo_open_dataset.utils import womd_lidar_utils"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Yt8jNg3IC3Bu"
   },
   "source": [
    "# Augmenting a WOMD scenario\n",
    "\n",
    "Augmenting an original WOMD scenario with laser data for its input frames takes three steps:\n",
    "1. Load a scenario proto message and check its `scenario_id` field.\n",
    "2. Find the corresponding compressed frame laser data file, which is named `{scenario_id}.tfrecord`.\n",
    "3. Load the frame laser data file, which is a scenario proto in which only the `compressed_frame_laser_data` field is populated, and merge the loaded data into the `compressed_frame_laser_data` field of the original scenario proto."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6uCajVOzGFaC"
   },
   "outputs": [],
   "source": [
    "def _load_scenario_data(tfrecord_file: str) -> scenario_pb2.Scenario:\n",
    "  \"\"\"Load a scenario proto from a tfrecord dataset file.\"\"\"\n",
    "  dataset = tf.data.TFRecordDataset(tfrecord_file, compression_type='')\n",
    "  data = next(iter(dataset))\n",
    "  return scenario_pb2.Scenario.FromString(data.numpy())\n",
    "\n",
    "WOMD_FILE = '/content/waymo-od/tutorial/womd_scenario_input.tfrecord'\n",
    "womd_original_scenario = _load_scenario_data(WOMD_FILE)\n",
    "print(womd_original_scenario.scenario_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tfTb3ZoaGLCv"
   },
   "outputs": [],
   "source": [
    "# The corresponding compressed laser data file is named\n",
    "# {scenario_id}.tfrecord. For simplicity, the laser data file\n",
    "# 'ee519cf571686d19.tfrecord' has been renamed to\n",
    "# 'womd_lidar_and_camera_data.tfrecord'.\n",
    "LIDAR_DATA_FILE = '/content/waymo-od/tutorial/womd_lidar_and_camera_data.tfrecord'\n",
    "womd_lidar_scenario = _load_scenario_data(LIDAR_DATA_FILE)\n",
    "scenario_augmented = womd_lidar_utils.augment_womd_scenario_with_lidar_points(\n",
    "    womd_original_scenario, womd_lidar_scenario)\n",
    "print(len(scenario_augmented.compressed_frame_laser_data))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "uzhDpVHFEEst"
   },
   "source": [
    "# Extract the LiDAR point clouds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "3DXErUfxsCiE"
   },
   "outputs": [],
   "source": [
    "frame_points_xyz = {}  # map from frame indices to point cloud xyz\n",
    "frame_points_feature = {}  # map from frame indices to point features\n",
    "\n",
    "def _get_laser_calib(\n",
    "    frame_lasers: compressed_lidar_pb2.CompressedFrameLaserData,\n",
    "    laser_name: dataset_pb2.LaserName.Name):\n",
    "  \"\"\"Returns the calibration of the named laser, or None if not found.\"\"\"\n",
    "  for laser_calib in frame_lasers.laser_calibrations:\n",
    "    if laser_calib.name == laser_name:\n",
    "      return laser_calib\n",
    "  return None\n",
    "\n",
    "# Extract point cloud xyz and features from each LiDAR and merge them for each\n",
    "# laser frame in the scenario proto.\n",
    "for frame_i, frame_lasers in enumerate(\n",
    "    scenario_augmented.compressed_frame_laser_data):\n",
    "  points_xyz_list = []\n",
    "  points_feature_list = []\n",
    "  # Frame pose, reshaped into a 4x4 transform matrix.\n",
    "  frame_pose = np.reshape(np.array(frame_lasers.pose.transform), (4, 4))\n",
    "  for laser in frame_lasers.lasers:\n",
    "    c = _get_laser_calib(frame_lasers, laser.name)\n",
    "    if laser.name == dataset_pb2.LaserName.TOP:\n",
    "      # Only the top LiDAR extraction takes the frame pose.\n",
    "      (points_xyz, points_feature,\n",
    "       points_xyz_return2,\n",
    "       points_feature_return2) = womd_lidar_utils.extract_top_lidar_points(\n",
    "           laser, frame_pose, c)\n",
    "    else:\n",
    "      (points_xyz, points_feature,\n",
    "       points_xyz_return2,\n",
    "       points_feature_return2) = womd_lidar_utils.extract_side_lidar_points(\n",
    "           laser, c)\n",
    "    # Collect both sets of returned points and features for this laser.\n",
    "    points_xyz_list.append(points_xyz.numpy())\n",
    "    points_xyz_list.append(points_xyz_return2.numpy())\n",
    "    points_feature_list.append(points_feature.numpy())\n",
    "    points_feature_list.append(points_feature_return2.numpy())\n",
    "  frame_points_xyz[frame_i] = np.concatenate(points_xyz_list, axis=0)\n",
    "  frame_points_feature[frame_i] = np.concatenate(points_feature_list, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6IDe4OESuDEG"
   },
   "outputs": [],
   "source": [
    "print(frame_points_xyz[0].shape)\n",
    "print(frame_points_feature[0].shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HINTDv91C1ES"
   },
   "source": [
    "### Show point cloud\n",
    "3D point clouds are rendered using an internal tool, which is unfortunately not publicly available yet. Here is an example of what they look like."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ZuV7dzzkyR4M"
   },
   "outputs": [],
   "source": [
    "from IPython.display import Image, display\n",
    "display(Image('/content/waymo-od/tutorial/womd_point_cloud.png'))"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "private_outputs": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
